/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_MATRIX_BASIC_BLOCK_MATMUL_CUSTOM_IMPL_H
#define EXAMPLES_MATRIX_BASIC_BLOCK_MATMUL_CUSTOM_IMPL_H
#include "kernel_operator.h"
#include "lib/matmul_intf.h"

// Static tiling configuration for the basic-block matmul: each inner iteration
// produces a baseM x baseN block of C while accumulating over baseK of the K axis.
constexpr MatmulConfig MM_CFG = GetBasicConfig(128, 256, 64); // baseM, baseN, baseK

/**
 * Per-core element offsets into the global A/B/C/bias tensors, selecting the
 * single-core block this AI core works on. All fields default to 0.
 */
struct BasicBlockMatrixOffset {
    int32_t offsetA{0};    // start of this core's block inside matrix A
    int32_t offsetB{0};    // start of this core's block inside matrix B
    int32_t offsetC{0};    // start of this core's block inside matrix C
    int32_t offsetBias{0}; // start of this core's slice of the bias vector
};

/**
 * Single-core basic-block matmul kernel: C = A^T * B (+ bias), with all four
 * matrices resident in global memory (GM) in ND format. The A-side MatmulType
 * is declared with its transpose flag set to true.
 */
template <typename aType, typename bType, typename cType, typename biasType>
class BasicBlockMatmulKernel {
    public:
        __aicore__ inline BasicBlockMatmulKernel(){};
        // Binds the GM buffers, computes this core's block offsets from the
        // tiling, and validates the system workspace pointer.
        __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling);
        // Runs the full matmul for this core; bias is applied only when the
        // compile-time flag hasBias is true.
        template <bool hasBias = false>
        __aicore__ inline void Process(AscendC::TPipe* pipe);
        // Matmul object parameterized with the static basic-block config MM_CFG;
        // note the `true` on the A-side type: A is consumed transposed.
        AscendC::Matmul<AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType, true>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>, MM_CFG> matmulObj;

    private:
        // Derives the per-core element offsets into A/B/C/bias from blockIdx
        // and the tiling (layout depends on the transpose flags).
        __aicore__ inline void CalcOffset(int32_t blockIdx, const TCubeTiling& tiling, BasicBlockMatrixOffset& matrixOffset,
                                          bool isAtrans, bool isBtrans);

        AscendC::GlobalTensor<aType> aGlobal;       // A in GM, shifted to this core's block by Init
        AscendC::GlobalTensor<bType> bGlobal;       // B in GM, shifted to this core's block by Init
        AscendC::GlobalTensor<cType> cGlobal;       // C in GM, shifted to this core's block by Init
        AscendC::GlobalTensor<biasType> biasGlobal; // bias in GM, shifted to this core's slice by Init
        TCubeTiling tiling;                         // copy of the tiling passed to Init
};

/**
 * Binds the GM buffers for A/B/C/bias and advances each tensor to the block
 * this core (GetBlockIdx) is responsible for, as computed by CalcOffset.
 *
 * Fix: the system-workspace null check used to be the LAST statement of this
 * function, where its early `return` was dead code; it now guards all setup.
 */
template <typename aType, typename bType, typename cType, typename biasType>
__aicore__ inline void BasicBlockMatmulKernel<aType, bType, cType, biasType>::Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c,
                                                                         GM_ADDR workspace, const TCubeTiling& tiling)
{
    // Do nothing if the framework's system workspace is not available.
    if (GetSysWorkSpacePtr() == nullptr) {
        return;
    }
    // NOTE(review): the `workspace` argument itself is unused here — presumably
    // the system workspace is installed by the launcher; confirm against callers.
    this->tiling = tiling;

    // Full-tensor GM views; sizes follow the ND shapes M x Ka, Kb x N, M x N, N.
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), tiling.M * tiling.Ka);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), tiling.Kb * tiling.N);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), tiling.M * tiling.N);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), tiling.N);

    // Shift each view to this core's block. A is transposed (matches the
    // MatmulType's transpose flag and Process's SetTensorA(..., true)); B is not.
    struct BasicBlockMatrixOffset matrixOffset;
    bool isAtrans = true;
    bool isBtrans = false;
    CalcOffset(AscendC::GetBlockIdx(), tiling, matrixOffset, isAtrans, isBtrans);
    aGlobal = aGlobal[matrixOffset.offsetA];
    bGlobal = bGlobal[matrixOffset.offsetB];
    cGlobal = cGlobal[matrixOffset.offsetC];
    biasGlobal = biasGlobal[matrixOffset.offsetBias];
}

/**
 * Runs the whole single-core matmul on the tensors bound by Init and writes the
 * result block to C in GM. Bias is applied only when hasBias is true (checked
 * at compile time, so the bias path is compiled out otherwise).
 *
 * NOTE(review): `pipe` is unused in this body — presumably kept for interface
 * symmetry with other kernel examples; confirm against callers.
 */
template <typename aType, typename bType, typename cType, typename biasType>
template <bool hasBias>
__aicore__ inline void BasicBlockMatmulKernel<aType, bType, cType, biasType>::Process(AscendC::TPipe* pipe)
{
    matmulObj.SetTensorA(aGlobal, true); // A matrix transpose — matches the `true` flag on the A-side MatmulType
    matmulObj.SetTensorB(bGlobal);
    if constexpr (hasBias) {
        matmulObj.SetBias(biasGlobal);
    }
    matmulObj.IterateAll(cGlobal); // compute all iterations and store the result to cGlobal
    matmulObj.End();               // release matmul resources
}

/**
 * Integer ceiling division: smallest q such that q * b >= a.
 * Returns 0 when the divisor is 0 (guard for degenerate tiling values).
 */
__aicore__ inline uint32_t Ceiling(uint32_t a, uint32_t b)
{
    return (b == 0) ? 0 : (a + b - 1) / b;
}

/**
 * Computes this core's element offsets into A/B/C/bias for the ND (row-major)
 * layouts, given the core index and the transpose flags.
 *
 * Cores are laid out M-major: blockIdx advances along M first, then N.
 *
 * Fix: guard against mSingleBlocks == 0 (singleCoreM == 0 makes Ceiling return
 * 0), which previously caused modulo/division by zero.
 */
template <typename aType, typename bType, typename cType, typename biasType>
__aicore__ inline void BasicBlockMatmulKernel<aType, bType, cType, biasType>::CalcOffset(int32_t blockIdx, const TCubeTiling& tiling,
    BasicBlockMatrixOffset& matrixOffset, bool isAtrans, bool isBtrans)
{
    auto mSingleBlocks = Ceiling(tiling.M, tiling.singleCoreM);
    if (mSingleBlocks == 0) {
        // Degenerate tiling: keep the caller's zero-initialized offsets.
        return;
    }
    auto mCoreIndx = blockIdx % mSingleBlocks;
    auto nCoreIndx = blockIdx / mSingleBlocks;

    // A: skip mCoreIndx row-blocks of singleCoreM rows, each row Ka elements long;
    // when A is transposed its block start is just a column offset (no Ka stride).
    matrixOffset.offsetA = isAtrans ? (mCoreIndx * tiling.singleCoreM)
                                    : (mCoreIndx * tiling.Ka * tiling.singleCoreM);
    // B: block start is a column offset of nCoreIndx * singleCoreN; when B is
    // transposed, skip whole row-blocks with stride Kb instead.
    matrixOffset.offsetB = isBtrans ? (nCoreIndx * tiling.Kb * tiling.singleCoreN)
                                    : (nCoreIndx * tiling.singleCoreN);
    // C (M x N, row-major): row-block offset plus column-block offset.
    matrixOffset.offsetC = mCoreIndx * tiling.N * tiling.singleCoreM + nCoreIndx * tiling.singleCoreN;
    // Bias is a length-N vector; only the N block index matters.
    matrixOffset.offsetBias = nCoreIndx * tiling.singleCoreN;
}
#endif // EXAMPLES_MATRIX_BASIC_BLOCK_MATMUL_CUSTOM_IMPL_H
